{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Error in cpuinfo: prctl(PR_SVE_GET_VL) failed\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from d2l import torch as d2l\n",
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "true_w = torch.tensor([2, -3.4])\n",
    "true_b = 4.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustDatasetForRegression(torch.utils.data.Dataset):\n",
    "    def __init__(self, true_w, true_b, num_samples):\n",
    "        self.true_w = true_w\n",
    "        self.true_b = true_b\n",
    "        self.num_samples = num_samples\n",
    "\n",
    "        self.features, self.labels = d2l.synthetic_data(true_w, true_b, num_samples)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = {\"inputs\": self.features[idx], \"labels\": self.labels[idx]}\n",
    "        return item\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = CustDatasetForRegression(true_w, true_b, 1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomModelForRegression(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(CustomModelForRegression, self).__init__()\n",
    "        self.net = nn.Sequential(nn.Linear(2, 1))\n",
    "\n",
    "    def forward(self, inputs, labels=None):\n",
    "        logits = self.net(inputs)\n",
    "\n",
    "        if labels is not None:\n",
    "            loss_fn = nn.MSELoss()\n",
    "            loss = loss_fn(logits, labels)\n",
    "            return {\"logits\": logits, \"loss\": loss}\n",
    "        else:\n",
    "            return {\"logits\": logits}\n",
    "\n",
    "    @classmethod\n",
    "    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):\n",
    "        model = cls(*model_args, **kwargs)\n",
    "        state_dict = torch.load(f\"{pretrained_model_name_or_path}/model.bin\")\n",
    "        model.load_state_dict(state_dict)\n",
    "        return model\n",
    "\n",
    "    def save_pretrained(self, save_directory):\n",
    "        self.to(torch.device(\"cpu\"))\n",
    "        torch.save(self.state_dict(), f\"{save_directory}/model.bin\")\n",
    "\n",
    "    def predict(self, inputs, device):\n",
    "        device = device or self.device\n",
    "        with torch.no_grad():\n",
    "            inputs = inputs.to(device)\n",
    "            out = self(inputs, None)\n",
    "            return out[\"logits\"].flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CustomModelForRegression()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CustomModelForRegression(\n",
       "  (net): Sequential(\n",
       "    (0): Linear(in_features=2, out_features=1, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='40' max='40' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [40/40 00:00, Epoch 20/20]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>31.782100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>21.786800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>14.083900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>8.247200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>4.171000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>1.636300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.325800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>0.008600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>0.000300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.000100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>0.000100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>0.000100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>0.000100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28</td>\n",
       "      <td>0.000100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.000100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>32</td>\n",
       "      <td>0.000100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>34</td>\n",
       "      <td>0.000100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>36</td>\n",
       "      <td>0.000100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>38</td>\n",
       "      <td>0.000100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>0.000100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=40, training_loss=4.102151058428717, metrics={'train_runtime': 0.0821, 'train_samples_per_second': 243729.476, 'train_steps_per_second': 487.459, 'total_flos': 0.0, 'train_loss': 4.102151058428717, 'epoch': 20.0})"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./results\",\n",
    "    num_train_epochs=20,\n",
    "    logging_strategy=\"epoch\",\n",
    "    per_device_train_batch_size=512,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,  # 自定义模型实例\n",
    "    args=training_args,  # 训练参数\n",
    "    train_dataset=data,\n",
    "    optimizers=(optimizer, None),  # 传递优化器\n",
    ")\n",
    "\n",
    "trainer.train()  # 开始训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_pretrained(\"./model/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_30289/3466821424.py:19: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  state_dict = torch.load(f\"{pretrained_model_name_or_path}/model.bin\")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "CustomModelForRegression(\n",
       "  (net): Sequential(\n",
       "    (0): Linear(in_features=2, out_features=1, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.from_pretrained(\"./model/\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter containing:\n",
      "tensor([[ 1.9996, -3.4005]], requires_grad=True) \n",
      " Parameter containing:\n",
      "tensor([4.2004], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "print(model.net[0].weight, \"\\n\", model.net[0].bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-2.0019, -7.6055])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict(torch.tensor([[2.0, 3.0], [6.0, 7.0]]), torch.device(\"cpu\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "unlock-hf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
